{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Tutorial 1: Ring Tensor\n",
    "In NssMPClib, the basic data structure is ```RingTensor```, a tensor whose elements live on a ring. It is the counterpart of a ```torch``` tensor. We convert a ```torch``` tensor to a ```RingTensor``` and then operate on it. Conversion supports several ```torch``` data types (int64, int32, float64, float32).\n",
    "Let's start by importing ```RingTensor```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:09.108379Z",
     "start_time": "2024-03-21T11:20:08.495246Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from NssMPC import RingTensor\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "### Conversion between torch tensors and RingTensor\n",
    "The library provides a ```convert_to_ring``` method that converts a ```torch.Tensor``` into a ```RingTensor```. The ```convert_to_real_field``` method converts a ```RingTensor``` back into a ```torch.Tensor```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:09.139254Z",
     "start_time": "2024-03-21T11:20:09.110437Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RingTensor\n",
      " value:tensor([ 65536, 131072, 196608]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "tensor([1., 2., 3.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Create torch tensor\n",
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "\n",
    "# Convert a tensor to ring\n",
    "x_on_ring = RingTensor.convert_to_ring(x)\n",
    "print(x_on_ring)\n",
    "\n",
    "# Convert a RingTensor to real field\n",
    "x_real_field = x_on_ring.convert_to_real_field()\n",
    "print(x_real_field)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "The library supports converting ```torch.int64```, ```torch.int32```, ```torch.float64``` and ```torch.float32``` tensors to the ring."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:09.154667Z",
     "start_time": "2024-03-21T11:20:09.141258Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RingTensor\n",
      " value:tensor([1, 2, 3]) \n",
      " dtype:int \n",
      " scale:1\n",
      "RingTensor\n",
      " value:tensor([1, 2, 3]) \n",
      " dtype:int \n",
      " scale:1\n",
      "RingTensor\n",
      " value:tensor([ 65536, 131072, 196608]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "RingTensor\n",
      " value:tensor([ 65536, 131072, 196608]) \n",
      " dtype:float \n",
      " scale:65536\n"
     ]
    }
   ],
   "source": [
    "# Convert different data type tensor to ring\n",
    "# torch.int64\n",
    "x_int64 = torch.tensor([1, 2, 3], dtype=torch.int64)\n",
    "\n",
    "x_int64_on_ring = RingTensor.convert_to_ring(x_int64)\n",
    "print(x_int64_on_ring)\n",
    "\n",
    "# torch.int32\n",
    "x_int32 = torch.tensor([1, 2, 3], dtype=torch.int32)\n",
    "\n",
    "x_int32_on_ring = RingTensor.convert_to_ring(x_int32)\n",
    "print(x_int32_on_ring)\n",
    "\n",
    "# torch.float64\n",
    "x_float64 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)\n",
    "\n",
    "x_float64_on_ring = RingTensor.convert_to_ring(x_float64)\n",
    "print(x_float64_on_ring)\n",
    "\n",
    "# torch.float32\n",
    "x_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)\n",
    "\n",
    "x_float32_on_ring = RingTensor.convert_to_ring(x_float32)\n",
    "print(x_float32_on_ring)"
   ]
  },
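  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed values suggest a fixed-point encoding: each float appears to be multiplied by the printed ```scale``` (65536, i.e. 2^16) and stored as an integer, while integer tensors keep a ```scale``` of 1. The cell below is a minimal plain-Python sketch of that idea, not NssMPClib's actual implementation. The helper names ```encode``` and ```decode``` and the rounding rule are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only, independent of NssMPClib.\n",
    "# SCALE mirrors the `scale` printed above; encode/decode are hypothetical helpers.\n",
    "SCALE = 65536  # 2**16\n",
    "\n",
    "def encode(values, scale=SCALE):\n",
    "    # Multiply each real number by the scale and round to an integer.\n",
    "    return [round(v * scale) for v in values]\n",
    "\n",
    "def decode(values, scale=SCALE):\n",
    "    # Divide each scaled integer by the scale to recover the real number.\n",
    "    return [v / scale for v in values]\n",
    "\n",
    "encoded = encode([1.0, 2.0, 3.0])\n",
    "print(encoded)          # [65536, 131072, 196608]\n",
    "print(decode(encoded))  # [1.0, 2.0, 3.0]"
   ]
  },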
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Operations on RingTensor\n",
    "Now let's look at what we can do with a ```RingTensor```.\n",
    "\n",
    "#### Arithmetic Operations\n",
    "We can carry out regular arithmetic operations between ```RingTensor``` objects. Each operation returns a ```RingTensor```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:09.170189Z",
     "start_time": "2024-03-21T11:20:09.155680Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Addition: tensor([3., 4., 5.], dtype=torch.float64)\n",
      "\n",
      "Subtraction tensor([-1.,  0.,  1.], dtype=torch.float64)\n",
      "\n",
      "Multiplication tensor([2., 4., 6.], dtype=torch.float64)\n",
      "\n",
      "Matrix Multiplication tensor([14.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Arithmetic operations between RingTensors\n",
    "x_on_ring = RingTensor.convert_to_ring(torch.tensor([1.0, 2.0, 3.0]))\n",
    "\n",
    "y_on_ring = RingTensor.convert_to_ring(torch.tensor([2.0]))\n",
    "\n",
    "\n",
    "# Addition\n",
    "res_on_ring = x_on_ring + y_on_ring\n",
    "print(\"\\nAddition:\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Subtraction\n",
    "res_on_ring = x_on_ring - y_on_ring\n",
    "print(\"\\nSubtraction\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Multiplication\n",
    "res_on_ring = x_on_ring * y_on_ring\n",
    "print(\"\\nMultiplication\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Matrix Multiplication\n",
    "y_on_ring = RingTensor.convert_to_ring(torch.tensor([[1.0], [2.0], [3.0]]))\n",
    "res_on_ring = x_on_ring @ y_on_ring\n",
    "print(\"\\nMatrix Multiplication\", res_on_ring.convert_to_real_field())"
   ]
  },
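  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under a fixed-point encoding, adding or subtracting two encoded values keeps the scale, but multiplying them squares it, so the product has to be divided by the scale once before decoding. The cell below sketches this with plain Python integers. It assumes the scale of 65536 printed earlier and uses a hypothetical ```fixed_mul``` helper, and it does not claim to reproduce how NssMPClib truncates internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only, independent of NssMPClib.\n",
    "SCALE = 65536  # assumed scale, taken from the printed RingTensor output\n",
    "\n",
    "def fixed_mul(a, b, scale=SCALE):\n",
    "    # a and b each carry one factor of scale, so a * b carries two;\n",
    "    # integer division by scale removes the extra factor.\n",
    "    return (a * b) // scale\n",
    "\n",
    "x = [1 * SCALE, 2 * SCALE, 3 * SCALE]  # encodes 1.0, 2.0, 3.0\n",
    "y = 2 * SCALE                          # encodes 2.0\n",
    "\n",
    "product = [fixed_mul(v, y) for v in x]\n",
    "print(product)                        # [131072, 262144, 393216]\n",
    "print([p / SCALE for p in product])   # [2.0, 4.0, 6.0]"
   ]
  },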
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "#### Comparisons\n",
    "Similarly, we can compute element-wise comparisons between ```RingTensor``` objects. Unlike arithmetic operations, a comparison yields ```True``` or ```False``` for each element, much like a comparison between ```torch``` tensors. As the printed output shows, the boolean values are still wrapped in a ```RingTensor```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:09.185787Z",
     "start_time": "2024-03-21T11:20:09.172243Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(x < y) : RingTensor\n",
      " value:tensor([ True,  True, False, False, False]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "\n",
      "(x <= y) : RingTensor\n",
      " value:tensor([ True,  True,  True, False, False]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "\n",
      "(x > y) : RingTensor\n",
      " value:tensor([False, False, False,  True,  True]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "\n",
      "(x >= y) : RingTensor\n",
      " value:tensor([False, False,  True,  True,  True]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "\n",
      "(x == y) : RingTensor\n",
      " value:tensor([False, False,  True, False, False]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "\n",
      "(x != y) : RingTensor\n",
      " value:tensor([ True,  True, False,  True,  True]) \n",
      " dtype:float \n",
      " scale:65536\n"
     ]
    }
   ],
   "source": [
    "# Comparisons between RingTensors\n",
    "x_on_ring = RingTensor.convert_to_ring(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))\n",
    "\n",
    "y_on_ring = RingTensor.convert_to_ring(torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0]))\n",
    "\n",
    "# Less than\n",
    "result = x_on_ring < y_on_ring\n",
    "print(\"\\n(x < y) :\", result)\n",
    "\n",
    "# Less than or equal\n",
    "result = x_on_ring <= y_on_ring\n",
    "print(\"\\n(x <= y) :\", result)\n",
    "\n",
    "# Greater than\n",
    "result = x_on_ring > y_on_ring\n",
    "print(\"\\n(x > y) :\", result)\n",
    "\n",
    "# Greater than or equal\n",
    "result = x_on_ring >= y_on_ring\n",
    "print(\"\\n(x >= y) :\", result)\n",
    "\n",
    "# Equal\n",
    "result = x_on_ring == y_on_ring\n",
    "print(\"\\n(x == y) :\", result)\n",
    "\n",
    "# Not Equal\n",
    "result = x_on_ring != y_on_ring\n",
    "print(\"\\n(x != y) :\", result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "#### Other operations\n",
    "The library also supports other operations on ```RingTensor```, such as ```reshape```, ```stack```, ```cat``` and ```unsqueeze```.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:10.724401Z",
     "start_time": "2024-03-21T11:20:09.187867Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Concatenation: \n",
      " tensor([1., 2., 3., 4., 5., 6.], dtype=torch.float64)\n",
      "\n",
      "Stacking: \n",
      " tensor([[1., 2., 3.],\n",
      "        [4., 5., 6.]], dtype=torch.float64)\n",
      "\n",
      "Reshape: \n",
      " tensor([[1., 2., 3., 4., 5., 6.]], dtype=torch.float64)\n",
      "\n",
      "UnSqueezing: \n",
      " tensor([[1.],\n",
      "        [2.],\n",
      "        [3.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "x_on_ring = RingTensor.convert_to_ring(torch.tensor([1.0, 2.0, 3.0]))\n",
    "y_on_ring = RingTensor.convert_to_ring(torch.tensor([4.0, 5.0, 6.0]))\n",
    "\n",
    "# Concatenation\n",
    "res_on_ring = RingTensor.cat((x_on_ring, y_on_ring))\n",
    "print(\"Concatenation: \\n\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Stacking\n",
    "res_on_ring = RingTensor.stack((x_on_ring, y_on_ring))\n",
    "print(\"\\nStacking: \\n\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Reshaping\n",
    "res_on_ring_after_reshape = res_on_ring.reshape(-1, 6)\n",
    "print(\"\\nReshape: \\n\", res_on_ring_after_reshape.convert_to_real_field())\n",
    "\n",
    "# UnSqueezing\n",
    "res_on_ring = x_on_ring.unsqueeze(dim=1)\n",
    "print(\"\\nUnSqueezing: \\n\", res_on_ring.convert_to_real_field())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Some functions, such as ```where```, ```random``` and ```arange```, are also supported by ```RingTensor```, but their usage differs slightly from the operations above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Where: \n",
      " tensor([4., 5., 6.], dtype=torch.float64)\n",
      "\n",
      "Random: \n",
      " tensor([[3, 7, 3],\n",
      "        [4, 7, 7]])\n",
      "\n",
      "Arange: \n",
      " tensor([0, 2, 4, 6, 8])\n"
     ]
    }
   ],
   "source": [
    "from NssMPC.common.ring.ring_tensor import RingTensor\n",
    "\n",
    "# Where\n",
    "x_on_ring = RingTensor.convert_to_ring(torch.tensor([1.0, 5.0, 3.0]))\n",
    "y_on_ring = RingTensor.convert_to_ring(torch.tensor([4.0, 2.0, 6.0]))\n",
    "condition = x_on_ring > y_on_ring\n",
    "\n",
    "res_on_ring = RingTensor.where(condition, x_on_ring, y_on_ring)\n",
    "print(\"\\nWhere: \\n\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Random\n",
    "res_on_ring = RingTensor.random(shape=(2, 3), dtype='int', device='cpu', down_bound=0, upper_bound=10)\n",
    "print(\"\\nRandom: \\n\", res_on_ring.convert_to_real_field())\n",
    "\n",
    "# Arange\n",
    "res_on_ring = RingTensor.arange(start=0, end=10, step=2, dtype='int', device='cpu')\n",
    "print(\"\\nArange: \\n\", res_on_ring.convert_to_real_field())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Note that the condition passed to ```where``` must be a ```RingTensor```, while each of the other two arguments can be either a ```RingTensor``` or an int."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
